# Copyright (c) 2024-present AI-Labs

from fastapi import APIRouter, Request

from datetime import datetime

import os
import uuid
import numpy as np
from PIL import Image
import base64
from io import BytesIO

from leffa.transform import LeffaTransform
from leffa.model import LeffaModel
from leffa.inference import LeffaInference
from leffa_utils.garment_agnostic_mask_predictor import AutoMasker
from leffa_utils.densepose_predictor import DensePosePredictor
from leffa_utils.utils import resize_and_center, get_agnostic_mask_hd, get_agnostic_mask_dc
from preprocess.humanparsing.run_parsing import Parsing
from preprocess.openpose.run_openpose import OpenPose

from configs import config


# NOTE(review): this module-level alias is not referenced anywhere in this
# file (every path below goes through config.service.leffa.model_path.leffa
# directly). Kept for backward compatibility in case other modules import
# it -- confirm before removing.
leffa_model_path = config.service

# Leffa predictor: wraps preprocessing and diffusion models for try-on.
class LeffaPredictor(object):
    """Virtual try-on / pose-transfer predictor built on the Leffa models.

    Loads all preprocessing models (mask, DensePose, human parsing,
    OpenPose) and three diffusion inference pipelines (VITON-HD try-on,
    DressCode try-on, pose transfer) once at construction time.
    """

    def __init__(self):
        """Load every model from the configured weight directory onto CUDA."""
        # Root directory holding all Leffa weights; hoisted so the long
        # config path is written once instead of on every line.
        model_root = config.service.leffa.model_path.leffa

        # Agnostic-mask predictor. NOTE(review): currently unused by
        # leffa_predict, which builds masks from parsing + keypoints
        # instead (see the commented-out branch there).
        self.mask_predictor = AutoMasker(
            densepose_path=f"{model_root}/densepose",
            schp_path=f"{model_root}/schp",
            device='cuda',
        )

        self.densepose_predictor = DensePosePredictor(
            config_path=f"{model_root}/densepose/densepose_rcnn_R_50_FPN_s1x.yaml",
            weights_path=f"{model_root}/densepose/model_final_162be9.pkl",
        )

        self.parsing = Parsing(
            atr_path=f"{model_root}/humanparsing/parsing_atr.onnx",
            lip_path=f"{model_root}/humanparsing/parsing_lip.onnx",
        )

        self.openpose = OpenPose(
            body_model_path=f"{model_root}/openpose/body_pose_model.pth",
        )

        # Virtual try-on pipeline, VITON-HD weights.
        vt_model_hd = LeffaModel(
            pretrained_model_name_or_path=f"{model_root}/stable-diffusion-inpainting",
            pretrained_model=f"{model_root}/virtual_tryon.pth",
            dtype="float16",
        )
        self.vt_inference_hd = LeffaInference(model=vt_model_hd)

        # Virtual try-on pipeline, DressCode weights.
        vt_model_dc = LeffaModel(
            pretrained_model_name_or_path=f"{model_root}/stable-diffusion-inpainting",
            pretrained_model=f"{model_root}/virtual_tryon_dc.pth",
            dtype="float16",
        )
        self.vt_inference_dc = LeffaInference(model=vt_model_dc)

        # Pose-transfer pipeline (SDXL inpainting backbone).
        pt_model = LeffaModel(
            pretrained_model_name_or_path=f"{model_root}/stable-diffusion-xl-1.0-inpainting-0.1",
            pretrained_model=f"{model_root}/pose_transfer.pth",
            dtype="float16",
        )
        self.pt_inference = LeffaInference(model=pt_model)

    def leffa_predict(
        self,
        src_image_path,  # person (model) image
        ref_image_path,  # garment image (or reference pose for pose_transfer)
        control_type="virtual_tryon",
        ref_acceleration=False,
        step=30,
        scale=2.5,
        seed=42,
        vt_model_type="viton_hd",
        vt_garment_type="upper_body",
        vt_repaint=False,
    ):
        """Generate a try-on (or pose-transfer) result from two image files.

        Args:
            src_image_path: path to the person (model) image.
            ref_image_path: path to the garment image, or the reference
                image when ``control_type`` is ``"pose_transfer"``.
            control_type: ``"virtual_tryon"`` or ``"pose_transfer"``.
            ref_acceleration: forwarded to the inference pipeline.
            step: number of diffusion inference steps.
            scale: classifier-free guidance scale.
            seed: RNG seed for reproducible generation.
            vt_model_type: ``"viton_hd"`` or ``"dress_code"``
                (virtual_tryon only).
            vt_garment_type: garment region, e.g. ``"upper_body"``.
            vt_repaint: forwarded to the inference pipeline as ``repaint``.

        Returns:
            ``output["generated_image"]`` from the inference pipeline.

        Raises:
            AssertionError: if ``control_type`` is invalid.
            ValueError: if ``vt_model_type`` is invalid for virtual try-on.
        """
        assert control_type in ["virtual_tryon", "pose_transfer"], "Invalid control type: {}".format(control_type)
        # Fail fast on an unknown model type: the branches below would
        # otherwise leave `mask`/`densepose` unset and crash much later
        # with a confusing AttributeError/UnboundLocalError.
        if control_type == "virtual_tryon" and vt_model_type not in ("viton_hd", "dress_code"):
            raise ValueError("Invalid vt_model_type: {}".format(vt_model_type))

        src_image = Image.open(src_image_path)
        ref_image = Image.open(ref_image_path)
        src_image = resize_and_center(src_image, 768, 1024)
        ref_image = resize_and_center(ref_image, 768, 1024)

        src_image_array = np.array(src_image)

        # Mask: the region of the person image the model may repaint.
        if control_type == "virtual_tryon":
            src_image = src_image.convert("RGB")
            # Parsing and keypoint models run at 384x512; the resulting
            # mask is scaled back up to the working resolution below.
            model_parse, _ = self.parsing(src_image.resize((384, 512)))
            keypoints = self.openpose(src_image.resize((384, 512)))
            if vt_model_type == "viton_hd":
                mask = get_agnostic_mask_hd(
                    model_parse, keypoints, vt_garment_type)
            else:  # "dress_code" (validated above)
                mask = get_agnostic_mask_dc(
                    model_parse, keypoints, vt_garment_type)
            mask = mask.resize((768, 1024))
            # Alternative masking path via AutoMasker, kept for reference:
            # garment_type_hd = "upper" if vt_garment_type in [
            #     "upper_body", "dresses"] else "lower"
            # mask = self.mask_predictor(src_image, garment_type_hd)["mask"]
        else:  # pose_transfer: repaint the entire image.
            mask = Image.fromarray(np.ones_like(src_image_array) * 255)

        # DensePose conditioning image.
        if control_type == "virtual_tryon":
            if vt_model_type == "viton_hd":
                # Segmentation map with the channel axis reversed
                # (presumably BGR -> RGB -- confirm against predictor).
                src_image_seg_array = self.densepose_predictor.predict_seg(
                    src_image_array)[:, :, ::-1]
                densepose = Image.fromarray(src_image_seg_array)
            else:  # "dress_code": part-index channel of IUV, replicated to 3 channels.
                src_image_iuv_array = self.densepose_predictor.predict_iuv(
                    src_image_array)
                src_image_seg_array = src_image_iuv_array[:, :, 0:1]
                src_image_seg_array = np.concatenate(
                    [src_image_seg_array] * 3, axis=-1)
                densepose = Image.fromarray(src_image_seg_array)
        else:  # pose_transfer: full IUV map, channel axis reversed.
            src_image_iuv_array = self.densepose_predictor.predict_iuv(
                src_image_array)[:, :, ::-1]
            densepose = Image.fromarray(src_image_iuv_array)

        # Leffa inference.
        transform = LeffaTransform()

        data = {
            "src_image": [src_image],
            "ref_image": [ref_image],
            "mask": [mask],
            "densepose": [densepose],
        }
        data = transform(data)

        # Pick the pipeline that matches the requested control/model type.
        if control_type == "virtual_tryon":
            inference = self.vt_inference_hd if vt_model_type == "viton_hd" else self.vt_inference_dc
        else:
            inference = self.pt_inference

        output = inference(
            data,
            ref_acceleration=ref_acceleration,
            num_inference_steps=step,
            guidance_scale=scale,
            seed=seed,
            repaint=vt_repaint,
        )
        return output["generated_image"]

    def leffa_predict_vt(self, src_image_path, ref_image_path, step, scale, seed):
        """Virtual try-on convenience wrapper with default garment/model types."""
        return self.leffa_predict(src_image_path, ref_image_path, step=step, scale=scale, seed=seed)


# Router for the virtual try-on endpoints.
router = APIRouter(
    prefix='/try_on/leffa',
    tags=['虚拟试穿'],
)

# Module-level singleton: all models are loaded once at import time so every
# request reuses the warm pipelines. NOTE(review): this makes importing the
# module slow and CUDA-dependent -- confirm that is intended.
leffa_predictor = LeffaPredictor()


def image_to_base64(pil_image: Image.Image) -> str:
    """Encode a PIL image as a base64 JPEG string (no data-URI prefix).

    Args:
        pil_image: the image to encode.

    Returns:
        The base64-encoded JPEG bytes as an ASCII string.
    """
    # JPEG cannot store an alpha channel/palette: Image.save raises on
    # RGBA/P-mode inputs, so convert defensively before encoding.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    buffered = BytesIO()
    pil_image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


@router.post("/v1")
async def tryon(request: Request):
    """Virtual try-on service endpoint.

    Expects a JSON body with data-URI image fields ``cloth_image``,
    ``cloth_mask``, ``person_image``, ``person_mask`` and optional tuning
    fields ``num_inference_steps``, ``guidance_scale``, ``seed``.

    Returns:
        ``{"images": [...]}`` -- generated images as ``data:image;base64,``
        strings.
    """
    # Fetch the user's request payload.
    data = await request.json()

    # Persist the uploaded images under a unique per-request directory.
    localdir = f"statics/upload/{datetime.now().strftime('%Y-%m-%d')}/{uuid.uuid4()}"
    os.makedirs(localdir, exist_ok=True)

    # Decode and save each upload. The masks are saved alongside the images
    # for traceability, but only person_image/cloth_image feed inference.
    for name in ("cloth_image", "cloth_mask", "person_image", "person_mask"):
        # Payloads arrive as data URIs ("data:image/png;base64,...."); keep
        # only the base64 part after the comma.
        image_b64 = data[name].split(',')[1]
        with open(f"{localdir}/{name}.png", "wb") as f:
            f.write(base64.b64decode(image_b64))

    # Run inference on the saved files. Tuning fields fall back to defaults
    # when missing or falsy (the original raised KeyError on a missing key).
    images = leffa_predictor.leffa_predict_vt(
        src_image_path=f"{localdir}/person_image.png",
        ref_image_path=f"{localdir}/cloth_image.png",
        step=int(data.get('num_inference_steps') or 30),
        scale=float(data.get('guidance_scale') or 2.5),
        seed=int(data.get('seed') or 42),
    )

    # Return every generated image as a data-URI base64 string.
    base64s = [f"data:image;base64,{image_to_base64(image)}" for image in images]
    return {"images": base64s}
